// Copyright (c) 2020-present, INSPUR Co, Ltd. All rights reserved.
// This source code is licensed under Apache 2.0 License.

#include "art_tree.h"
#include "pure_mem/epoche.h"
#include "rocksdb/slice.h"
#include "tree_node.h"
#include <algorithm>
#include <assert.h>
#include <functional>
#include <iostream>

namespace syn_art_fullkey {

// Builds an empty tree. The root is a freshly created NodeType256 node at
// level 0 with an empty prefix. `loadKey` is stored as-is; other methods of
// this class call it (parseTid2Key) to turn a stored leaf value back into
// the key it belongs to.
Tree::Tree(ILoadKey *loadKey)
    : root_(N::newTreeNode(NTypes::NodeType256, 0, rocksdb::Slice())),
      loadKey_(loadKey) {}
// Tears the tree down by handing the root node to N::removeAll.
Tree::~Tree() { N::removeAll(root_); }

// Debug helper: recursively visits `node` and everything below it and asserts
// that no inner node reports isLocked(). Leaf entries are skipped.
void Tree::checkAnyChildrenLocked(N *node) const {
  if (!N::isLeaf(node)) {
    assert(!node->isLocked());

    // Buffer sized for the whole token range requested below (0..255).
    std::tuple<uint8_t, N *> kids[256];
    uint32_t kidCount = 0;
    N::getChildrenSmall(node, 0, 255, kids, kidCount, 256);

    uint32_t idx = 0;
    while (idx < kidCount) {
      checkAnyChildrenLocked(std::get<1>(kids[idx]));
      ++idx;
    }
  }
}

// Splits the compressed prefix of `node` by placing a new NodeType4 node
// between `parentNode` and `node`.
//
//   level        - level given to the new intermediate node.
//   keyEnd       - true when the inserted key ends at the split point; the
//                  value is then stored as the new node's full-key leaf.
//   keyToken     - child token under which `value` is inserted when keyEnd is
//                  false (ignored otherwise).
//   value        - value being inserted.
//   node         - in: the node whose prefix is split; out (on success): the
//                  copy of that node carrying the shortened prefix.
//   samePrefix   - prefix given to the new intermediate node.
//   remainPrefix - the part of node's prefix that did not match; its first
//                  byte becomes the child token of the copy, the rest becomes
//                  the copy's prefix. Must hold at least one byte.
//   parentNode / parentKey - the slot that currently points at `node`.
//
// Returns true on success. Returns false when either lock could not be taken;
// the caller is expected to restart, and `node` is left unchanged.
bool Tree::splitPrefix(uint32_t level, bool keyEnd, uint8_t keyToken,
                       void *value, N *&node, const rocksdb::Slice &samePrefix,
                       rocksdb::Slice &remainPrefix, N *parentNode,
                       uint8_t parentKey) {
  bool needRestart = false;
  // Lock the node being split; give up if that is not possible right now.
  node->lockVersionOrRestart(needRestart);
  if (needRestart)
    return false;
  // 1) Create the new node that will become the parent of `node`, carrying
  // the common prefix and the given level.
  N *HighLevelNode = N::newTreeNode(NTypes::NodeType4, level, samePrefix);

  if (!keyEnd) {
    N::insert(HighLevelNode, keyToken, N::setLeaf(value));
  } else { // The key ends here, so its value is stored as the full-key leaf.
    HighLevelNode->setFullKeyLeaf(nullptr, value);
  }
  // 2) Copy `node` with the unmatched prefix minus its first byte, and hang
  // the copy under the new node using that first byte as the child token.
  N *copyNode = N::copyNodeWithNewPrefix(
      node, rocksdb::Slice((remainPrefix.data_ + 1), remainPrefix.size() - 1));
  N::insert(HighLevelNode, remainPrefix[0], copyNode);

  // 3) Lock the parent, update it to point to the new node, then unlock.
  parentNode->writeLockOrRestart(needRestart);
  if (needRestart) {
    // Nothing was published yet: drop both new nodes and release `node`.
    free(HighLevelNode);
    free(copyNode);
    node->writeUnlock();
    return false;
  }

  N::change(parentNode, parentKey, HighLevelNode);
  parentNode->writeUnlock();
  // NOTE(review): presumably the copy inherits the locked state of `node`,
  // which is why it is unlocked here - confirm in copyNodeWithNewPrefix.
  copyNode->writeUnlock();
  N *oldNode = node;
  node = copyNode;
  oldNode->writeUnlockObsolete(); // Retire the replaced node object.
  // The old node is not freed here; it is handed to DeleteWhileNoRefs.
  rocksdb::DeleteWhileNoRefs::getInstance()->markNodeForDeletion(
      oldNode, rocksdb::DELETION_TYPE_MALLOC);
  return true;
}

// Inserts `newValue` under key `k`, or reports the value already stored
// there.
//
// Returns true when a new entry was created. `retValue` is then set to the
// value of an entry adjacent to the insertion point (which one depends on the
// branch taken, see below); it is nullptr when no such entry was found.
// Returns false when `k` is already present; `retValue` then holds the
// existing value and nothing is inserted.
//
// The descent starts over from the root whenever a lock cannot be taken or
// the node being modified has been marked obsolete.
bool Tree::insertOrGetLG(const rocksdb::Slice &k, void *newValue,
                         void *&retValue) {

  bool needRestart = false;
  N *node = nullptr;
  N *nextNode = root_;
  N *parentNode = nullptr;
  uint8_t parentKey = 0;
  uint8_t nodeKey = 0;
  uint32_t level = 0;
  uint32_t nextLevel = 0;
restart:
  // Reset the whole cursor. `node` must be cleared as well, otherwise the
  // root would be given a stale parent on the first iteration after a restart.
  level = 0;
  nodeKey = 0;
  node = nullptr;
  nextNode = root_;
  needRestart = false;

  while (true) {
    parentNode = node;
    parentKey = nodeKey;
    node = nextNode;
    // NOTE(review): the version is read but never validated in this function.
    auto v = node->getVersion();
    (void)v;

    nextLevel = level;
    rocksdb::Slice remainingPrefix;
    switch (matchPrefix(node, k, nextLevel, remainingPrefix)) {
    case 0: { // Prefix matched; leave the switch and keep descending.
      break;
    }
    case 1:
      if (nextLevel >= k.size()) { // The key ended inside the node's prefix.
        if (!splitPrefix(nextLevel, true, 0, newValue, node,
                         rocksdb::Slice(&(k.data_[level]), nextLevel - level),
                         remainingPrefix, parentNode, parentKey))
          goto restart;
      } else {
        if (!splitPrefix(nextLevel, false, k[nextLevel], newValue, node,
                         rocksdb::Slice(&(k.data_[level]), nextLevel - level),
                         remainingPrefix, parentNode, parentKey))
          goto restart;
      }
      // `node` now refers to the split-off copy, whose prefix compared as
      // greater than the key; report its first entry.
      retValue = getMinTID(node, 0, true);
      return true; // A new leaf was created.

    case -1: // Prefix compared as smaller than the key.
      if (!splitPrefix(nextLevel, false, k[nextLevel], newValue, node,
                       rocksdb::Slice(&(k.data_[level]), nextLevel - level),
                       remainingPrefix, parentNode, parentKey)) {
        goto restart;
      }
      // Report the last entry of the split-off copy.
      retValue = getMaxTID(node, 255, true);
      return true;
    default:
      assert(false);
      break;
    }

    level = nextLevel;
    if (level == k.size()) {
      // The key ends exactly at this node: its value lives in the node's
      // full-key leaf. Retry until the slot is either found occupied or set.
      while (true) {
        if (node->isObsolete())
          goto restart;
        void *ret = node->getFullKeyLeaf();
        if (ret != nullptr) {
          retValue = ret;
          return false; // Key exists, insert failed.
        } else {
          if (node->setFullKeyLeaf(ret, newValue)) {
            retValue = getMinTID(node, 0, false);
            return true;
          }
        }
      }
    }

    nodeKey = k[level];
    nextNode = N::getChild(nodeKey, node);

    if (nextNode == nullptr) {
      // Empty child slot: store the value as a leaf right here.
      node->lockVersionOrRestart(needRestart);
      if (needRestart)
        goto restart;

      N::insertGrow(node, parentNode, parentKey, nodeKey, N::setLeaf(newValue),
                    needRestart);
      node->writeUnlock();

      if (needRestart)
        goto restart;
      // Start from a known state: the caller's previous content of retValue
      // must not decide whether the fallback lookup below runs.
      retValue = nullptr;
      // Prefer an entry under a larger child token; otherwise fall back to
      // one under a smaller token (or this node's own full-key value).
      if (nodeKey < 255)
        retValue = getMinTID(node, nodeKey + 1, false);
      if (retValue == nullptr && nodeKey > 0) {
        retValue = getMaxTID(node, nodeKey - 1, true);
      }
      return true;
    }

    if (N::isLeaf(nextNode)) {
      // The slot holds a leaf: either it is our key, or the two keys have to
      // be separated by a new inner node.
      void *nextNodeLeafValue = N::getLeaf(nextNode);
      level++;
      node->lockVersionOrRestart(needRestart);
      if (needRestart)
        goto restart;

      rocksdb::Slice key;
      loadKey_->parseTid2Key(nextNodeLeafValue, key);

      // Length of the run both keys share starting at `level`.
      uint32_t prefixLength = 0;
      uint32_t compLength =
          static_cast<uint32_t>(std::min(key.size(), k.size()));
      while ((level + prefixLength) < compLength &&
             key[level + prefixLength] == k[level + prefixLength]) {
        prefixLength++;
      }

      if ((level + prefixLength) == k.size() &&
          k.size() == key.size()) { // The leaf already stores this very key.
        node->writeUnlock();
        retValue = nextNodeLeafValue;
        return false;
      }

      // New inner node holding the shared run as its prefix. Whichever key
      // ends at the split point goes into the full-key leaf, the other one
      // (or both) become children.
      N *HighLevelNode =
          N::newTreeNode(NTypes::NodeType4, level + prefixLength,
                         rocksdb::Slice(&(k.data_[level]), prefixLength));
      if ((level + prefixLength) >= k.size()) {
        HighLevelNode->setFullKeyLeaf(nullptr, newValue);
      } else {
        N::insert(HighLevelNode, k[level + prefixLength], N::setLeaf(newValue));
      }
      if ((level + prefixLength) >= key.size()) {
        HighLevelNode->setFullKeyLeaf(nullptr, nextNodeLeafValue);
      } else {
        N::insert(HighLevelNode, key[level + prefixLength], nextNode);
      }

      N::change(node, k[level - 1], HighLevelNode);
      node->writeUnlock();

      retValue = nextNodeLeafValue; // The old leaf is the nearest entry.
      return true;
    }
    level++;
  }
}

// Returns the first stored value found under `node`, looking in this order:
//   - the leaf value, when `node` itself is a leaf;
//   - the node's own full-key value, when `includeFullKey` is set;
//   - recursively, the children reported by N::getChildrenSmall for the token
//     range [start, 255] (at most four are inspected).
// Returns nullptr for a null `node` or when nothing was found.
void *Tree::getMinTID(N *node, uint8_t start, bool includeFullKey) const {
  if (node == nullptr) {
    return nullptr;
  }
  if (N::isLeaf(node)) {
    return N::getLeaf(node);
  }

  if (includeFullKey) {
    if (void *ownValue = node->getFullKeyLeaf()) {
      return ownValue;
    }
  }

  std::tuple<uint8_t, N *> candidates[4];
  uint32_t found = 0;
  N::getChildrenSmall(node, start, 255, candidates, found, 4);

  // Stop at the first child subtree that yields a value.
  void *result = nullptr;
  for (uint32_t i = 0; i < found && result == nullptr; ++i) {
    result = getMinTID(std::get<1>(candidates[i]), 0, true);
  }
  return result;
}

// Counterpart of getMinTID. Returns the first stored value found by:
//   - the leaf value, when `node` itself is a leaf;
//   - recursively, the children reported by N::getChildrenLarge for the token
//     range [0, end] (at most four), visited from the last array slot to the
//     first;
//   - finally the node's own full-key value, when `includeFullKey` is set.
// Returns nullptr for a null `node` or when nothing was found.
void *Tree::getMaxTID(N *node, uint8_t end, bool includeFullKey) const {
  if (node == nullptr) {
    return nullptr;
  }
  if (N::isLeaf(node)) {
    return N::getLeaf(node);
  }

  std::tuple<uint8_t, N *> candidates[4];
  uint32_t found = 0;
  N::getChildrenLarge(node, 0, end, candidates, found, 4);

  // Count down so the array is walked back to front.
  for (uint32_t remaining = found; remaining > 0; --remaining) {
    void *hit = getMaxTID(std::get<1>(candidates[remaining - 1]), 255, true);
    if (hit != nullptr) {
      return hit;
    }
  }

  if (!includeFullKey) {
    return nullptr;
  }
  return node->getFullKeyLeaf();
}

// Looks up a stored value for key `k`, starting at the root.
//   equal   - an entry whose key equals `k` may be returned.
//   greater - an entry whose key is greater than `k` may be returned.
// Both flags are forwarded unchanged to seekChildren. Returns nullptr when no
// entry qualifies.
void *Tree::seek(const rocksdb::Slice &k, bool equal, bool greater) const {
  // checkAnyChildrenLocked(root_);
  return seekChildren(root_, k, 0, true, equal, greater);
}

// Recursive worker behind seek().
//   node         - subtree to search (may be null or a leaf entry).
//   k            - key being looked up.
//   level        - number of key bytes consumed on the way to `node`.
//   prePathMatch - true while the path walked so far equals the key's bytes.
//   equal        - an entry whose key equals `k` is acceptable.
//   greater      - an entry whose key is greater than `k` is acceptable.
// Returns the value of the first acceptable entry, or nullptr.
void *Tree::seekChildren(N *node, const rocksdb::Slice &k, uint32_t level, bool prePathMatch,
                         bool equal, bool greater) const {
  if (node == nullptr) {
    return nullptr;
  }

  if (N::isLeaf(node)) {
    // Either the path matched completely or the leaf's remaining bytes were
    // never stored in the tree.
    void *leafTid = N::getLeaf(node);
    if (prePathMatch) {
      // Whole key consumed on a matching path: the leaf key is >= k.
      if (equal && greater && k.size() <= level) {
        return leafTid;
      }
    } else if (greater) {
      return leafTid;
    }

    // Otherwise decide by comparing the complete keys.
    rocksdb::Slice leafKey;
    loadKey_->parseTid2Key(leafTid, leafKey);
    const int order = leafKey.compare(k);
    if (order > 0 && greater) {
      return leafTid;
    }
    const bool exactAllowed = prePathMatch ? equal : greater;
    if (order == 0 && exactAllowed) {
      return leafTid;
    }
    return nullptr;
  }

  uint32_t depth = level;
  rocksdb::Slice unmatchedTail;
  const int prefixOrder = Tree::matchPrefix(node, k, depth, unmatchedTail);
  if (prefixOrder == -1) {
    // Node prefix < key, so every entry below this node is < key.
    return nullptr;
  }
  if (prefixOrder == 1) {
    // Node prefix > key, so the first entry below this node is > key.
    return greater ? getMinTID(node, 0, true) : nullptr;
  }

  // Prefix matched completely. The node's own full-key value comes first.
  if (node->getFullKeyLeaf() != nullptr) {
    const bool exactHit = prePathMatch && equal && k.size() == depth;
    const bool anyLarger = !prePathMatch && greater;
    if (exactHit || anyLarger) {
      return node->getFullKeyLeaf();
    }
  }

  // Token range to scan: from the key's next byte (0 when the key is used
  // up) either to that same token or, for "greater" seeks, to 255.
  const bool keyHasToken = k.size() > depth;
  const uint8_t first =
      keyHasToken ? static_cast<uint8_t>(k[depth]) : static_cast<uint8_t>(0);
  const uint8_t last = greater ? static_cast<uint8_t>(255) : first;

  std::tuple<uint8_t, N *> kids[4];
  uint32_t kidCount = 0;
  N::getChildrenSmall(node, first, last, kids, kidCount, 4);

  for (uint32_t i = 0; i < kidCount; i++) {
    const uint8_t token = std::get<0>(kids[i]);
    N *child = std::get<1>(kids[i]);
    // Keep matching against the key only while still on its exact path;
    // any other child simply contributes its first entry.
    const bool onKeyPath = prePathMatch && (!keyHasToken || token == first);
    void *hit = onKeyPath
                    ? seekChildren(child, k, depth + 1, true, equal, greater)
                    : getMinTID(child, 0, true);
    if (hit != nullptr) {
      return hit;
    }
  }

  return nullptr;
}

// Compares the compressed prefix stored in node `n` with the bytes of `k`
// starting at `level`.
//
//   level        - in: first key position to compare; out: advanced by one for
//                  every prefix byte that matched.
//   remainPrefix - set to the not-yet-matched tail of the node's prefix when
//                  the comparison stopped early; left untouched otherwise.
//
// Returns
//    0  the whole prefix matched (also when the prefix is empty),
//    1  the key ran out inside the prefix, or the first differing prefix byte
//       is greater than the key byte,
//   -1  the first differing prefix byte is smaller than the key byte.
//
// Bytes are compared as unsigned values. Comparing them as plain `char` would
// make bytes >= 0x80 order as negative on targets where `char` is signed,
// which disagrees with the uint8_t child tokens and with Slice::compare used
// elsewhere in this file.
int Tree::matchPrefix(N *n, const rocksdb::Slice &k, uint32_t &level,
                      rocksdb::Slice &remainPrefix) {
  Prefix &p = N::getPrefix(n);
  int ret = 0;
  if (p.length_ > 0) {
    uint32_t i = 0;
    for (i = 0; i < p.length_; ++i, ++level) {
      if (level >= k.size()) {
        ret = 1;
        break;
      }
      const uint8_t prefixByte = static_cast<uint8_t>(p.prefix_[i]);
      const uint8_t keyByte = static_cast<uint8_t>(k[level]);
      if (prefixByte != keyByte) {
        ret = (prefixByte > keyByte) ? 1 : -1;
        break;
      }
    }
    if (i != p.length_) {
      remainPrefix.data_ = p.prefix_ + i;
      remainPrefix.size_ = p.length_ - i;
    }
  }
  return ret;
}

} // namespace syn_art_fullkey